sup-extra HDU-4609 3-idiots FFT
#HDU-4609 3-idiots FFT
题意就是在一组数中找三个数,计算能够组成三角形的概率是多少。这里就体现了FFT的用法其二,计算a+b的可能情况时转换成多项式乘法的运用。
由于x^a * x^b = x^(a+b),所以我们只要将数的值视作幂值,数的个数视作系数,就能将加法运算变为乘法运算了。这就是解这题的核心部分了。另一个难点在于如何处理构成三角形的条件a+b>c的呢?我们需要利用容斥原理把不能构成的情况和c不是最大的数的情况列出来然后减去。不能构成的情况有:
-
a+b时同一个数加了两次。解决:遍历数组减去num[a[i]+a[i]]
-
a+b与b+a重复统计。解决:减掉了上面的情况后将num/=2;
-
a+b后与a比较。此时多加了n-1,在结果上减掉。
-
a+b 中a>c,b<c. 当c为第i个数时,此时多加了(n-i-1)*i个数
-
a+b 中a>c,b>c. 当c为第i个数时,此时多加了 (n-i-1)*(n-i-2)/2个数
所以只要减掉这些情况就能得到正确的值了。
除此之外还遇到了比如FFT的最大数目不对啊,要用long long 啊,遍历的时候越界啊之类的智障错误……
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <map>
#include <cmath>
#define maxl (1<<17)
#define pi 3.141592653589793238462643383
using namespace std;
struct complex
{
double re,im;
complex (double r=0.0,double i=0.0){re=r;im=i;}
}a[maxl*2],w[2][maxl*2];
complex operator +(const complex&x,const complex&y)
{
return complex(x.re+y.re,x.im+y.im);
}
complex operator -(const complex&x,const complex&y)
{
return complex(x.re-y.re,x.im-y.im);
}
complex operator *(const complex&x,const complex&y)
{
return complex(x.re*y.re-x.im*y.im,x.im*y.re+x.re*y.im);
}
long long n,m,rev[maxl*2],sum[maxl*2],ori[maxl*2],num[maxl*2];
void init()//calculate reverse
{
for(int i=0;i<n;i++)
{
int coun=i,tmp=0;
for(int j=1;j<n;j<<=1,coun>>=1)
{
tmp<<=1;
tmp|=(coun&1);
}
rev[i]=tmp;
}
for(int i=0;i<n;i++)
{
w[0][i]=w[1][i]=complex(cos(2*pi*i/n),sin(2*pi*i/n));
w[1][i].im=-w[0][i].im;
}
}
void FFT(complex *a,int order)
{
complex x,y;
for(int i=0;i<n;i++)
{
if(i<rev[i])swap(a[i],a[rev[i]]);
}
for(int i=1;i<n;i<<=1)
{
for(int j=0,t=n/(i<<1);j<n;j+=i<<1)
for(int k=0,l=0;k<i;k++,l+=t)
{
x=w[order][l]*a[j+k+i];
y=a[j+k];
a[j+k]=y+x;
a[j+k+i]=y-x;
}
}
if(order)for(int i=0;i<n;i++) a[i].re/=n;
}
int main(void)
{
int t;
scanf("%d",&t);
memset(a,0,sizeof(a));
while(t--)
{
scanf("%lld",&m);
int most=0;
for(int i=0;i<m;i++)
{
scanf("%lld",&ori[i]);
a[ori[i]].re++;
if(most<ori[i])
most=ori[i];
}
for(n=1;n<=most;n<<=1);n<<=1;
init();
FFT(a,0);
for(int i=0;i<n;i++) a[i]=a[i]*a[i];
FFT(a,1);
for(int i=0;i<n;i++)
{
num[i]=(long long)(a[i].re+0.5);
//cout<<i<<' '<<num[i]<<endl;
}
for(int i=0;i<m;i++)
num[ori[i]+ori[i]]--;
for(int i=0;i<n;i++)
num[i]/=2;
sum[0]=0;
long long res=0;
for(int i=1;i<=n;i++)
{
sum[i]=sum[i-1]+num[i];//前缀和
}
/*for(int i=0;i<n;i++)
cout<<i<<' '<<num[i]<<' '<<sum[i]<<endl;*/
sort(ori,ori+m);
for(long long i=0;i<m;i++)
{
res+=sum[n-1]-sum[ori[i]];//总和
res-=(long long)(m-1-i)*i;//一大一小
res-=(long long)(m-1-i)*(m-2-i)/2;//两个大
res-=m-1;//自身
//cout<<sum[n-1]-sum[ori[i]]<<' '<<(m-1-i)*i<<(m-1-i)*(m-2-i)/2<<m-1<<endl;
}
long long tot=m*(m-1)*(m-2)/6;
printf("%.7lf\n",(double)(res*1.0/tot));
for(int i=0;i<n;i++)
a[i].re=a[i].im=0;
}
return 0;
}